#!/usr/bin/env python
"""Namelist creator for CIME's driver.
"""
import os, sys

# Root of the CIME tree: four directory levels above the directory holding this file.
_CIMEROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..","..","..","..")
# Add <CIMEROOT>/scripts/Tools to the module search path
# (presumably so the standard_script_setup import below resolves).
sys.path.append(os.path.join(_CIMEROOT, "scripts", "Tools"))

import shutil, glob, itertools
from standard_script_setup import *
from CIME.case import Case
from CIME.nmlgen import NamelistGenerator
from CIME.utils import expect
from CIME.utils import get_model, get_time_in_seconds, get_timestamp
from CIME.buildnml import create_namelist_infile, parse_input
from CIME.XML.files import Files
#pylint: disable=undefined-variable
logger = logging.getLogger(__name__)

###############################################################################
def _create_drv_namelists(case, infile, confdir, nmlgen, files):
###############################################################################
    """Write the driver namelist and NUOPC configuration files for a case.

    Steps, in order:
      1. build the ``config`` dictionary from case variables and initialize
         the namelist defaults,
      2. derive the per-component coupling intervals from NCPL_BASE_PERIOD
         and the ``<COMP>_NCPL`` values,
      3. write ``drv_in`` and ``nuopc.runconfig`` into ``confdir``,
      4. create the run sequence via ``_create_runseq``,
      5. if any of the ``drv_flds_in_files`` exist under CASEROOT, check them
         for conflicting settings and write ``CaseDocs/drv_flds_in``.

    Arguments:
    case    -- case object, queried through get_value/get_values
    infile  -- namelist input file list handed to ``nmlgen.init_defaults``
    confdir -- directory that receives ``drv_in`` and ``nuopc.runconfig``
    nmlgen  -- namelist generator built from the driver definition file
    files   -- Files object used to locate the drv namelist definition directory

    Invalid settings are reported through ``expect``.
    """

    #--------------------------------
    # Set up config dictionary
    #--------------------------------
    config = {}
    cime_model = get_model()
    config['cime_model'] = cime_model
    config['iyear'] = case.get_value('COMPSET').split('_')[0]
    config['BGC_MODE'] = case.get_value("CCSM_BGC")
    config['CPL_I2O_PER_CAT'] = case.get_value('CPL_I2O_PER_CAT')
    config['COMP_RUN_BARRIERS'] = case.get_value('COMP_RUN_BARRIERS')
    config['DRV_THREADING'] = case.get_value('DRV_THREADING')
    config['CPL_ALBAV'] = case.get_value('CPL_ALBAV')
    config['CPL_EPBAL'] = case.get_value('CPL_EPBAL')
    config['FLDS_WISO'] = case.get_value('FLDS_WISO')
    config['BUDGETS'] = case.get_value('BUDGETS')
    config['MACH'] = case.get_value('MACH')
    config['MPILIB'] = case.get_value('MPILIB')
    config['OS'] = case.get_value('OS')
    config['glc_nec'] = 0 if case.get_value('GLC_NEC') == 0 else case.get_value('GLC_NEC')
    config['single_column'] = 'true' if case.get_value('PTS_MODE') else 'false'
    config['timer_level'] = 'pos' if case.get_value('TIMER_LEVEL') >= 1 else 'neg'
    config['bfbflag'] = 'on' if case.get_value('BFBFLAG') else 'off'
    config['continue_run'] = '.true.' if case.get_value('CONTINUE_RUN') else '.false.'
    config['flux_epbal'] = 'ocn' if case.get_value('CPL_EPBAL')  == 'ocn' else 'off'
    config['atm_grid'] = case.get_value('ATM_GRID')

    # needed for determining the run sequence as well as glc_renormalize_smb
    config['COMP_ATM'] = case.get_value("COMP_ATM")
    config['COMP_ICE'] = case.get_value("COMP_ICE")
    config['COMP_GLC'] = case.get_value("COMP_GLC")
    config['COMP_LND'] = case.get_value("COMP_LND")
    config['COMP_OCN'] = case.get_value("COMP_OCN")
    config['COMP_ROF'] = case.get_value("COMP_ROF")
    config['COMP_WAV'] = case.get_value("COMP_WAV")

    if ((case.get_value("COMP_ROF") == 'mosart' and case.get_value("MOSART_MODE") == 'NULL') or
        (case.get_value("COMP_ROF") == 'rtm' and case.get_value("RTM_MODE") == 'NULL') or
        (case.get_value("ROF_GRID") == 'null')):
        config['ROF_MODE'] = 'null'

    # a hybrid run is presented to the namelist defaults as a startup run
    if case.get_value('RUN_TYPE') == 'startup':
        config['run_type'] = 'startup'
    elif case.get_value('RUN_TYPE') == 'hybrid':
        config['run_type'] = 'startup'
    elif case.get_value('RUN_TYPE') == 'branch':
        config['run_type'] = 'branch'

    #----------------------------------------------------
    # Initialize namelist defaults
    #----------------------------------------------------
    nmlgen.init_defaults(infile, config)

    if case.get_value('MEDIATOR_READ_RESTART'):
        nmlgen.set_value('mediator_read_restart', value='.true.')
    else:
        nmlgen.set_value('mediator_read_restart', value='.false.')

    #--------------------------------
    # Overwrite: set brnch_retain_casename
    #--------------------------------
    start_type = nmlgen.get_value('start_type')
    if start_type != 'startup':
        if case.get_value('CASE') == case.get_value('RUN_REFCASE'):
            nmlgen.set_value('brnch_retain_casename' , value='.true.')

    # set aquaplanet if appropriate
    if config['COMP_OCN'] == 'docn' and 'aqua' in case.get_value("DOCN_MODE"):
        nmlgen.set_value('aqua_planet' , value='.true.')

    #--------------------------------
    # Overwrite: set component coupling frequencies
    #--------------------------------
    # basedt is the length of the coupling base period in seconds
    ncpl_base_period  = case.get_value('NCPL_BASE_PERIOD')
    if ncpl_base_period == 'hour':
        basedt = 3600
    elif ncpl_base_period == 'day':
        basedt = 3600 * 24
    elif ncpl_base_period == 'year':
        if case.get_value('CALENDAR') == 'NO_LEAP':
            basedt = 3600 * 24 * 365
        else:
            expect(False, "Invalid CALENDAR for NCPL_BASE_PERIOD %s " %ncpl_base_period)
    elif ncpl_base_period == 'decade':
        if case.get_value('CALENDAR') == 'NO_LEAP':
            basedt = 3600 * 24 * 365 * 10
        else:
            # the base period itself is valid here; it is the calendar that is not
            expect(False, "Invalid CALENDAR for NCPL_BASE_PERIOD %s " %ncpl_base_period)
    else:
        expect(False, "invalid NCPL_BASE_PERIOD NCPL_BASE_PERIOD %s " %ncpl_base_period)

    if basedt < 0:
        expect(False, "basedt invalid overflow for NCPL_BASE_PERIOD %s " %ncpl_base_period)


    # determine coupling intervals
    comps = case.get_values("COMP_CLASSES")
    mindt = basedt
    coupling_times = {}
    for comp in comps:
        ncpl = case.get_value(comp.upper() + '_NCPL')
        if ncpl is not None:
            cpl_dt = basedt // int(ncpl)
            totaldt = cpl_dt * int(ncpl)
            if totaldt != basedt:
                expect(False, " %s ncpl doesn't divide base dt evenly" %comp)
            nmlgen.add_default(comp.lower() + '_cpl_dt', value=cpl_dt)
            coupling_times[comp.lower() + '_cpl_dt'] = cpl_dt
            mindt = min(mindt, cpl_dt)

    # sanity check
    comp_atm = case.get_value("COMP_ATM")
    if comp_atm is not None and comp_atm not in('datm', 'xatm', 'satm'):
        atmdt = int(basedt / case.get_value('ATM_NCPL'))
        expect(atmdt == mindt, 'Active atm should match shortest model timestep atmdt={} mindt={}'
               .format(atmdt, mindt))

    #--------------------------------
    # Overwrite: set start_ymd
    #--------------------------------
    # RUN_STARTDATE has its '-' separators removed (e.g. YYYY-MM-DD -> YYYYMMDD)
    run_startdate = "".join(str(x) for x in case.get_value('RUN_STARTDATE').split('-'))
    nmlgen.set_value('start_ymd', value=run_startdate)

    #--------------------------------
    # Overwrite: set tprof_option and tprof_n - if tprof_total is > 0
    #--------------------------------
    # This would be better handled inside the alarm logic in the driver routines.
    # Here supporting only nday(s), nmonth(s), and nyear(s).

    stop_option = case.get_value('STOP_OPTION')
    if 'nyear' in stop_option:
        tprofoption = 'ndays'
        tprofmult = 365
    elif 'nmonth' in stop_option:
        tprofoption = 'ndays'
        tprofmult = 30
    elif 'nday' in stop_option:
        tprofoption = 'ndays'
        tprofmult = 1
    else:
        tprofmult = 1
        tprofoption = 'never'

    tprof_total = case.get_value('TPROF_TOTAL')
    if ((tprof_total > 0) and (case.get_value('STOP_DATE') < 0) and ('ndays' in tprofoption)):
        stop_n = case.get_value('STOP_N')
        stopn = tprofmult * stop_n
        tprofn = int(stopn / tprof_total)
        if tprofn < 1:
            tprofn = 1
        nmlgen.set_value('tprof_option', value=tprofoption)
        nmlgen.set_value('tprof_n'     , value=tprofn)

    # Set up the pause_component_list if pause is active
    pauseo = case.get_value('PAUSE_OPTION')
    if pauseo != 'never' and pauseo != 'none':
        pausen = case.get_value('PAUSE_N')
        pcl = nmlgen.get_default('pause_component_list')
        nmlgen.add_default('pause_component_list', pcl)
        # Check to make sure pause_component_list is valid
        pcl = nmlgen.get_value('pause_component_list')
        if pcl != 'none' and pcl != 'all':
            pause_comps = pcl.split(':')
            comp_classes = case.get_values("COMP_CLASSES")
            for comp in pause_comps:
                expect(comp == 'drv' or comp.upper() in comp_classes,
                       "Invalid PAUSE_COMPONENT_LIST, %s is not a valid component type"%comp)
            # End for
        # End if
        # Set esp interval
        if 'nstep' in pauseo:
            esp_time = mindt
        else:
            esp_time = get_time_in_seconds(pausen, pauseo)

        nmlgen.set_value('esp_cpl_dt', value=esp_time)
    # End if pause is active

    #--------------------------------
    # (1) Specify input data list file
    #--------------------------------
    # remove any existing list so that it is regenerated below
    data_list_path = os.path.join(case.get_case_root(), "Buildconf", "cpl.input_data_list")
    if os.path.exists(data_list_path):
        os.remove(data_list_path)

    #--------------------------------
    # (2) Write namelist file drv_in and initial input dataset list.
    #--------------------------------
    namelist_file = os.path.join(confdir, "drv_in")
    drv_namelist_groups = ["papi_inparm", "pio_default_inparm", "prof_inparm", "debug_inparm"]
    nmlgen.write_output_file(namelist_file, data_list_path=data_list_path, groups=drv_namelist_groups)

    #--------------------------------
    # (3) Write nuopc.runconfig file and add to input dataset list.
    #--------------------------------

    # Determine valid components
    valid_comps = []
    for item in case.get_values("COMP_CLASSES"):
        comp = case.get_value("COMP_" + item)
        valid = True
        # stub comps
        if comp == 's' + item.lower():
            valid = False
        # xcpl_comps
        elif comp == 'x' + item.lower():
            if item != 'ESP': #no esp xcpl component
                if case.get_value(item + "_NX") == "0" and case.get_value(item + "_NY") == "0":
                    valid = False
        # special case - mosart or rtm in NULL mode.  Only the ROF class is
        # affected, and only the mode variable of the selected ROF model is
        # consulted (same pairing as used for config['ROF_MODE'] above).
        elif item == 'ROF' and comp in ('mosart', 'rtm'):
            rof_mode_var = 'MOSART_MODE' if comp == 'mosart' else 'RTM_MODE'
            if case.get_value(rof_mode_var) == 'NULL':
                valid = False
        if valid:
            valid_comps.append(item)

    # set the driver rpointer file if there is only one non-stub component then skip mediator
    # (two entries in valid_comps means CPL plus a single component)
    if len(valid_comps) == 2 and  "dwav" not in case.get_value("COMP_WAV") and "dlnd" not in case.get_value("COMP_LND"):
        valid_comps.remove("CPL")
        nmlgen.set_value('mediator_present', value='.false.')
        nmlgen.set_value("drv_restart_pointer", value="none")
        nmlgen.set_value("component_list", value=" ".join(valid_comps))
    else:
        nmlgen.set_value("drv_restart_pointer", value="rpointer.cpl")
        valid_comps_string = " ".join(valid_comps)
        # the coupler class is listed as MED in the component list
        nmlgen.set_value("component_list", value=valid_comps_string.replace("CPL","MED"))

    logger.info("Writing nuopc_runseq for components {}".format(valid_comps))
    nuopc_config_file = os.path.join(confdir, "nuopc.runconfig")
    nmlgen.write_nuopc_config_file(nuopc_config_file, data_list_path=data_list_path)

    #--------------------------------
    # (4) Write nuopc.runseq
    #--------------------------------
    _create_runseq(case, coupling_times, valid_comps)

    #--------------------------------
    # (5) Write drv_flds_in
    #--------------------------------
    # In the following, all values come simply from the infiles - no default values need to be added
    # FIXME - do want to add the possibility that will use a user definition file for drv_flds_in

    caseroot = case.get_value('CASEROOT')
    nmlgen.add_default('drv_flds_in_files')
    drvflds_files = nmlgen.get_default('drv_flds_in_files')
    # keep only the candidate files that actually exist under CASEROOT
    infiles = []
    for drvflds_file in drvflds_files:
        candidate = os.path.join(caseroot, drvflds_file)
        if os.path.isfile(candidate):
            infiles.append(candidate)

    if infiles:

        # First read the drv_flds_in files and make sure that
        # for any key there are not two conflicting values
        dicts = {}
        for flds_path in infiles:
            dict_ = {}
            with open(flds_path) as myfile:
                for line in myfile:
                    # only "name = value" lines that contain no '!' are considered
                    if "=" in line and '!' not in line:
                        name, var = line.partition("=")[::2]
                        name = name.strip()
                        var = var.strip()
                        dict_[name] = var
            dicts[flds_path] = dict_

        for first,second in itertools.combinations(dicts.keys(),2):
            compare_drv_flds_in(dicts[first], dicts[second], first, second)

        # Now create drv_flds_in
        config = {}
        definition_dir = os.path.dirname(files.get_value("NAMELIST_DEFINITION_FILE", attribute={"component":"drv"}))
        definition_file = [os.path.join(definition_dir, "namelist_definition_drv_flds.xml")]
        # a separate generator is used for the drv_flds definition
        nmlgen = NamelistGenerator(case, definition_file, files=files)
        skip_entry_loop = True
        nmlgen.init_defaults(infiles, config, skip_entry_loop=skip_entry_loop)
        drv_flds_in = os.path.join(caseroot, "CaseDocs", "drv_flds_in")
        nmlgen.write_output_file(drv_flds_in)

###############################################################################
def _create_runseq(case, coupling_times, valid_comps):
###############################################################################
    """Provide the nuopc.runseq run sequence file for the case.

    Three cases are handled:
      * a ``nuopc.runseq`` in CASEROOT is copied as-is to RUNDIR and CaseDocs;
      * with exactly one valid component, a mediator-free sequence is written
        to CaseDocs and copied to RUNDIR;
      * otherwise a compset-specific ``gen_runseq`` is imported from the
        driver's ``runseq`` directory and called with the case and coupling
        times.

    case           -- case object, queried through get_value
    coupling_times -- dict mapping '<comp>_cpl_dt' to a coupling interval
    valid_comps    -- list of component class names taking part in the run
    """

    caseroot = case.get_value("CASEROOT")
    user_file = os.path.join(caseroot, "nuopc.runseq")
    rundir = case.get_value("RUNDIR")
    casedocs = os.path.join(caseroot, "CaseDocs")

    if os.path.exists(user_file):

        # Determine if there is a user run sequence file in CASEROOT, use it
        shutil.copy(user_file, rundir)
        shutil.copy(user_file, casedocs)
        logger.info("NUOPC run sequence: copying custom run sequence from case root")
        return

    if len(valid_comps) == 1:

        # Create run sequence with no mediator
        runseq_file = os.path.join(casedocs, "nuopc.runseq")
        dtime = coupling_times[valid_comps[0].lower() + '_cpl_dt']
        # the context manager closes the file even if a write fails
        with open(runseq_file, "w") as outfile:
            outfile.write("runSeq:: \n")
            outfile.write("@" + str(dtime) + " \n")
            outfile.write("  " + valid_comps[0] + " \n")
            outfile.write("@  \n")
            outfile.write(":: \n")
        shutil.copy(runseq_file, rundir)
        return

    # Create a run sequence file appropriate for target compset
    comp_atm = case.get_value("COMP_ATM")
    comp_ice = case.get_value("COMP_ICE")
    comp_glc = case.get_value("COMP_GLC")
    comp_lnd = case.get_value("COMP_LND")
    comp_ocn = case.get_value("COMP_OCN")

    # make the runseq_* generator modules importable (added only once)
    runseq_dir = os.path.join(_CIMEROOT, "src", "drivers", "nuopc", "cime_config", "runseq")
    if runseq_dir not in sys.path:
        sys.path.append(runseq_dir)

    if (comp_ice == "cice" and comp_atm == 'datm' and comp_ocn == "docn"):
        from runseq_D import gen_runseq
    elif (comp_lnd == 'dlnd' and comp_glc == "cism"):
        from runseq_TG import gen_runseq
    elif (comp_atm == 'ufsatm' and comp_ocn == "mom" and comp_ice == 'cice'):
        from runseq_NEMS import gen_runseq
    else:
        from runseq_general import gen_runseq

    # create the run sequence
    gen_runseq(case, coupling_times)

###############################################################################
def compare_drv_flds_in(first, second, infile1, infile2):
###############################################################################
    """Abort via ``expect`` if two drv_flds_in dictionaries disagree.

    first, second    -- dicts of name -> value read from the two files
    infile1, infile2 -- paths of the files, used only in the error message

    For the first key found in both dicts with differing values, the key and
    both values are printed and ``expect(False, ...)`` is called.
    """
    common_names = set(first.keys()).intersection(second.keys())
    for name in common_names:
        value1 = first[name]
        value2 = second[name]
        if value1 == value2:
            continue
        print('Key: {}, \n Value 1: {}, \n Value 2: {}'.format(name, value1, value2))
        expect(False, "incompatible settings in drv_flds_in from \n %s \n and \n %s"
               % (infile1, infile2))

###############################################################################
def _create_component_modelio_namelists(confdir, case, files):
###############################################################################
    """Write the per-component modelio namelists and runconfig sections.

    For every component class (and every instance of it) a
    ``<comp>_modelio.nml[_NNNN]`` file is written into CASEBUILD/cplconf and a
    ``<COMP>_modelio[_NNNN]::`` section holding ``diro`` and ``logfile`` is
    appended to ``nuopc.runconfig``.  The ``cpl`` class is written under the
    name ``med``/``MED`` and additionally gets a ``DRV`` section.

    confdir -- ignored: it is recomputed from CASEBUILD below
    case    -- case object, queried through get_value/get_values
    files   -- Files object used to locate the drv namelist definition directory
    """

    # will need to create a new namelist generator
    infiles = []
    definition_dir = os.path.dirname(files.get_value("NAMELIST_DEFINITION_FILE", attribute={"component":"drv"}))
    definition_file = [os.path.join(definition_dir, "namelist_definition_modelio.xml")]

    confdir = os.path.join(case.get_value("CASEBUILD"), "cplconf")
    # log id: taken from the LID environment variable when set, else a timestamp
    lid = os.environ["LID"] if "LID" in os.environ else get_timestamp("%y%m%d-%H%M%S")

    #if we are in multi-coupler mode the number of instances of mediator will be the max
    # of any NINST_* value
    # multi_driver must always be bound: it is read for every component below
    multi_driver = bool(case.get_value("MULTI_DRIVER"))
    maxinst = case.get_value("NINST_MAX") if multi_driver else 1

    nuopc_config_file = os.path.join(confdir, "nuopc.runconfig")
    for model in case.get_values("COMP_CLASSES"):
        model = model.lower()
        is_cpl = model == 'cpl'
        # the cpl class is emitted under the mediator name
        file_prefix = "med" if is_cpl else model
        section_name = "MED" if is_cpl else model.upper()

        with NamelistGenerator(case, definition_file) as nmlgen:
            config = {}
            config['component'] = model
            entries = nmlgen.init_defaults(infiles, config, skip_entry_loop=True)
            if maxinst == 1 and not is_cpl and not multi_driver:
                inst_count = case.get_value("NINST_" + model.upper())
            else:
                inst_count = maxinst

            for inst_index in range(1, inst_count + 1):
                # determine instance string (empty for a single instance)
                inst_string = '_{:04d}'.format(inst_index) if inst_count > 1 else ""

                # Write out just the pio_inparm to the output file
                for entry in entries:
                    nmlgen.add_default(entry)

                modelio_file = file_prefix + "_modelio.nml" + inst_string
                nmlgen.write_nuopc_modelio_file(os.path.join(confdir, modelio_file))

                # Output the following to nuopc.runconfig
                moddiro = case.get_value('RUNDIR')
                sections = [(section_name, file_prefix + inst_string + ".log." + str(lid))]
                if is_cpl:
                    # also write out a driver log file
                    sections.append(("DRV", 'drv' + inst_string + ".log." + str(lid)))

                with open(nuopc_config_file, 'a') as outfile:
                    for name, logfile in sections:
                        outfile.write("{}_modelio{}::\n".format(name, inst_string))
                        outfile.write("     {}{}{}".format("diro = ", moddiro,"\n"))
                        outfile.write("     {}{}{}".format("logfile = ", logfile,"\n"))
                        outfile.write("::\n\n")


###############################################################################
def buildnml(case, caseroot, component):
###############################################################################
    """Build all driver namelist/config files and stage them in RUNDIR.

    Creates CASEBUILD/cplconf if needed, generates ``drv_in``,
    ``nuopc.runconfig``, the run sequence, ``drv_flds_in`` and the
    ``*modelio*`` files, then copies them to RUNDIR together with the field
    dictionary (as ``fd.yaml``).

    case      -- case object, queried through get_value
    caseroot  -- case root directory (holds SourceMods and user_nl_cpl)
    component -- must be "drv"

    Raises AttributeError when ``component`` is not "drv".
    """
    if component != "drv":
        raise AttributeError("buildnml for the nuopc driver only handles component "
                             "'drv', got {!r}".format(component))

    confdir = os.path.join(case.get_value("CASEBUILD"), "cplconf")
    if not os.path.isdir(confdir):
        os.makedirs(confdir)

    # NOTE: User definition *replaces* existing definition.
    # TODO: Append instead of replace?
    user_xml_dir = os.path.join(caseroot, "SourceMods", "src.drv")

    expect (os.path.isdir(user_xml_dir),
            "user_xml_dir %s does not exist " %user_xml_dir)

    files = Files(comp_interface="nuopc")

    # TODO: to get the right attributes of COMP_ROOT_DIR_CPL in evaluating definition_file - need
    # to do the following first - this needs to be changed so that the following two lines are not needed!
    comp_root_dir_cpl = files.get_value( "COMP_ROOT_DIR_CPL",{"component":"drv-nuopc"}, resolved=False)
    files.set_value("COMP_ROOT_DIR_CPL", comp_root_dir_cpl)

    definition_file = [files.get_value("NAMELIST_DEFINITION_FILE", {"component": "drv-nuopc"})]
    user_definition = os.path.join(user_xml_dir, "namelist_definition_drv.xml")
    if os.path.isfile(user_definition):
        definition_file = [user_definition]

    # create the namelist generator object - independent of instance
    nmlgen = NamelistGenerator(case, definition_file)

    # create cplconf/namelist
    infile_text = ""

    # determine infile list for nmlgen
    user_nl_file = os.path.join(caseroot, "user_nl_cpl")
    namelist_infile = os.path.join(confdir, "namelist_infile")
    create_namelist_infile(case, user_nl_file, namelist_infile, infile_text)
    infile = [namelist_infile]

    # create the files nuopc.runconfig, nuopc.runseq, drv_in and drv_flds_in
    _create_drv_namelists(case, infile, confdir, nmlgen, files)

    # create the files comp_modelio.nml where comp = [atm, lnd...]
    _create_component_modelio_namelists(confdir, case, files)

    # set rundir
    rundir = case.get_value("RUNDIR")

    # copy drv_in and nuopc.runconfig to rundir
    shutil.copy(os.path.join(confdir,"drv_in"), rundir)
    shutil.copy(os.path.join(confdir,"nuopc.runconfig"), rundir)

    # copy drv_flds_in to rundir (it only exists when drv_flds_in input files were found)
    drv_flds_in = os.path.join(caseroot, "CaseDocs", "drv_flds_in")
    if os.path.isfile(drv_flds_in):
        shutil.copy(drv_flds_in, rundir)

    # copy all *modelio* files to rundir
    for filename in glob.glob(os.path.join(confdir, "*modelio*")):
        shutil.copy(filename, rundir)

    # copy the field dictionary selected by COUPLING_MODE (fd_cesm.yaml or
    # fd_nems.yaml) to rundir under the fixed name fd.yaml
    cimeroot = case.get_value("CIMEROOT")
    fd_dir = os.path.join(cimeroot, "src","drivers","nuopc","mediator")
    coupling_mode = case.get_value('COUPLING_MODE')
    if coupling_mode == 'cesm':
        filename = os.path.join(fd_dir,"fd_cesm.yaml")
    else:
        filename = os.path.join(fd_dir,"fd_nems.yaml")
    shutil.copy(filename, os.path.join(rundir, "fd.yaml"))

###############################################################################
def _main_func():
    """Script entry point: parse the command line and build the drv namelists."""
    case_dir = parse_input(sys.argv)

    with Case(case_dir) as opened_case:
        buildnml(opened_case, case_dir, "drv")

# Generate the driver namelists only when run as a script, not when imported.
if __name__ == "__main__":
    _main_func()
